import numpy as np
from scipy.io import loadmat
from pathlib import Path
from PIL import Image
from demo2_mainbody import LegendreDecomposition1
from revise_lgd import LegendreDecomposition


# Load the Butterfly image cube. NOTE: the absolute path is machine-specific.
data = loadmat('/tmp/pycharm_project_272/realdata/Butterfly.mat')
# asanyarray returns ndarray (sub)class instances unchanged and converts
# anything else into an ndarray.
image_data = np.asanyarray(data['img'])


def _print_extrema(values):
    """Print the maximum and minimum of ``values`` as ``max=`` / ``min=`` lines."""
    print("max=", values.max())
    print("min=", values.min())


_print_extrema(image_data)

print(f"Image data type: {type(image_data)}")
print(f"Image data shape: {image_data.shape}")

# 81 x 81 spatial window and (up to) the first 9 bands, despite the variable
# name. Assuming the image is at least 330 x 330 x 9, this is
# 81 * 81 * 9 = 3**10 values.
image_3channels = image_data[249:330, 249:330, :9]

reshaped_image = image_3channels.flatten()
print("Processed image shape:", reshaped_image.shape)
_print_extrema(reshaped_image)


def out_put_coordinates(tensor):
    """Collect index lists from a 0/1 mask tensor.

    Indices are listed in ``np.argwhere`` order (row-major / C order).

    Args:
        tensor: Array compared element-wise against 1 and 0.

    Returns:
        A tuple ``(coordinates, coordinates1, coordinates2)``:

        * ``coordinates``: index tuples where ``tensor == 1``.
        * ``coordinates1``: a copy of ``coordinates`` with the *second*
          zero-valued index appended, when at least two zeros exist.
        * ``coordinates2``: every zero-valued index except the first one.
          This is empty when there are fewer than two zeros.

    For masks produced by ``change_parameter``, the first zero in row-major
    order is the origin ``(0, ..., 0)``. Both ``coordinates1`` and
    ``coordinates2`` therefore skip the origin.
    """
    coordinates = [tuple(index) for index in np.argwhere(tensor == 1)]
    zero_coordinates = [tuple(index) for index in np.argwhere(tensor == 0)]

    coordinates1 = coordinates.copy()
    if len(zero_coordinates) >= 2:
        coordinates1.append(zero_coordinates[1])

    # Slicing instead of `del lst[0]`: a mask with no zeros now yields an empty
    # list instead of raising IndexError.
    coordinates2 = zero_coordinates[1:]
    return coordinates, coordinates1, coordinates2


def change_parameter(tensor, k):
    """Build a 0/1 mask over ``tensor``'s index space selecting low-order positions.

    An index gets 1 when the number of its non-zero coordinates is between
    1 and ``k`` inclusive. The all-zero origin and every index with more than
    ``k`` non-zero coordinates get 0. Only ``tensor.shape`` is used; the
    values of ``tensor`` are ignored.

    Args:
        tensor: Array whose shape defines the shape of the mask.
        k: Maximum number of non-zero coordinates for a selected index.

    Returns:
        Integer ``np.ndarray`` (dtype ``int``) with the same shape as ``tensor``.

    Raises:
        ValueError: If ``k`` is not strictly less than ``tensor.ndim``.
    """
    dims = tensor.shape
    if k >= len(dims):
        raise ValueError("k must be strictly less than the tensor's dimension.")

    # np.indices stacks one coordinate grid per axis. Counting non-zeros along
    # axis 0 therefore gives, for each position, how many of its coordinates
    # are non-zero. This replaces a per-element Python loop (np.nditer), which
    # was slow and failed outright on zero-size shapes.
    nonzero_counts = np.count_nonzero(np.indices(dims), axis=0)
    binary_tensor = np.zeros(dims, dtype=int)
    binary_tensor[(nonzero_counts >= 1) & (nonzero_counts <= k)] = 1
    return binary_tensor


def calculate_s(P):
    """Return ``s = log(sum(P) / min(P)) / log(N)``, where ``N`` is the number of entries of ``P``.

    Args:
        P: Numeric array.

    Returns:
        The scalar ``s``. NumPy produces inf/nan (with a RuntimeWarning)
        instead of raising when ``min(P) <= 0`` or when ``P`` has a single
        entry (``log(1) == 0``).
    """
    total, smallest = np.sum(P), np.min(P)
    n_entries = np.prod(P.shape)
    return np.log(total / smallest) / np.log(n_entries)


def is_c_in_range(P, s, c, alpha):
    """Check whether ``c`` lies in the interval derived from ``P``'s shape, ``s`` and ``alpha``.

    With ``d = P.ndim``, ``I_i = P.shape[i]`` and ``N = prod(P.shape)``, each
    mode ``i`` defines

        b_i = 2**d * ((s - 1) * d + 1) * log(N)
              / (((1 - 1/alpha) * I_i + 1) * prod_{j != i} (I_j + 1))

    and the interval ``[-b_i, b_i]``. The admissible range is
    ``[max_i(-b_i), min_i(b_i)]``. A message describing the outcome is printed.

    Args:
        P: Array whose shape determines the bounds (its values are not used).
        s: Scalar, e.g. the value returned by ``calculate_s``.
        c: Value to test.
        alpha: Non-zero scalar in the ``(1 - 1/alpha)`` factor.

    Returns:
        ``(min_upper_bound, flag)``, where ``flag`` is 1 if ``c`` is within the
        range and 0 otherwise.
    """
    dims = P.shape
    n_modes = len(dims)
    log_size = np.log(np.prod(dims))

    # The numerator is identical for every mode; only the denominator varies.
    numerator = 2 ** n_modes * ((s - 1) * n_modes + 1) * log_size
    denominators = [
        ((1 - 1 / alpha) * dims[i] + 1)
        * np.prod([(dims[j] + 1) for j in range(n_modes) if j != i])
        for i in range(n_modes)
    ]

    lower = max(-numerator / den for den in denominators)
    upper = min(numerator / den for den in denominators)

    inside = lower <= c <= upper
    verdict = 'within' if inside else 'outside'
    print(f'c = {c} is {verdict} the range: [{lower}, {upper}]')
    return upper, int(inside)


# Sweep the interaction order k and, for each k, the tensor order d (> k).
# For every (k, d), the first 3**d values of the flattened patch are reshaped
# into a d-way 3 x 3 x ... x 3 tensor. Both LegendreDecomposition variants are
# fitted, and the resulting ``c`` is checked with ``is_c_in_range``.
for k in range(2, 6):
    results_c = []
    results_upperbound = []
    rights_line = []
    results_d = []
    folder = Path(f'size_test/butterfly_k={k}_dim_increase')
    for d in range(k + 1, 11):

        print('d,k=', d, k)
        num = 3 ** d
        if num > reshaped_image.size:
            raise ValueError(f"Cannot extract {num} elements: exceeds total tensor size {reshaped_image.size}")
        sub_tensor = reshaped_image[:num]

        # Bug fix: the tensor needs d modes of size 3 (3**d elements). The old
        # reshape to [3] * k always failed for d > k, and change_parameter
        # also requires P.ndim > k.
        P = sub_tensor.reshape([3] * d)
        print(P.shape)

        s = calculate_s(P)
        print('s=', s)

        binary_tensor = change_parameter(P, k)
        coordinates, coordinates1, coordinates_complement = out_put_coordinates(binary_tensor)

        ld_ori = LegendreDecomposition(solver='ng', max_iter=50000, verbose=0, learning_rate=0.001)
        reconst_tensor_ori = ld_ori.fit_transform(P, coordinates)
        print('Reconstruction error(RSE): {}'.format(ld_ori.reconstruction_err_))

        # The second model receives the first model's fitted theta
        # (presumably as an initialization; confirm in demo2_mainbody).
        ld_imp = LegendreDecomposition1(solver='ng', max_iter=500, verbose=0, learning_rate=0.001)
        reconst_tensor_imp = ld_imp.fit_transform(P, coordinates, coordinates1, coordinates_complement, ld_ori.theta)
        print('Reconstruction error(RSE): {}'.format(ld_imp.reconstruction_err_))

        results_c.append(ld_imp.c)
        print('c=', ld_imp.c)

        # NOTE(review): the tensor order d is passed as `alpha`; confirm this
        # is the intended value for the bound.
        upperbounds, rights = is_c_in_range(P, s, ld_imp.c, d)
        results_upperbound.append(upperbounds)
        rights_line.append(rights)
        results_d.append(d)

        folder.mkdir(parents=True, exist_ok=True)

        # Rewrite every result file after each d, so the files always hold
        # all results collected so far for this k.
        for filename, values in (
            ('results_upperbound.txt', results_upperbound),
            ('rights_line_append.txt', rights_line),
            ('results_c.txt', results_c),
            ('results_d.txt', results_d),
        ):
            with open(folder / filename, 'w') as file:
                file.writelines(str(value) + '\n' for value in values)

